## master preamble ------------------------------------------------------
## Start from an empty workspace (including hidden objects).
rm(list=ls(all.names=TRUE))

## inherit directories from stata -----------
## Usage: Rscript <this file> <project dir> [<library dir>]
##   args[1] = project root (a trailing '/' is appended)
##   args[2] = optional library path to load packages from ('' = default libs)
## Both default to '' so that missing arguments fall back to the current
## working directory / default library paths instead of failing later on an
## undefined (or NA) dir.proj / dir.lib.
args     = commandArgs(trailingOnly = TRUE)
dir.proj = ''
dir.lib  = ''
if (length(args) >= 1) dir.proj = paste0(args[1],'/')
if (length(args) >= 2) dir.lib  = args[2]

## if loading from library -------------------
if(dir.lib!='') {
  ## load dependencies --------------------
  library(R.utils,lib.loc=dir.lib)
  library(ggplot2,lib.loc=dir.lib)
  library(tidyr,lib.loc=dir.lib)
  library(stringr,lib.loc=dir.lib)
  library(forcats,lib.loc=dir.lib)
  library(lubridate,lib.loc=dir.lib)
  library(haven,lib.loc=dir.lib)
  library(data.table,lib.loc=dir.lib)
  library(backports,lib.loc=dir.lib)
  ## load main packages ------------------
  library(rio,lib.loc=dir.lib)
  library(dplyr,lib.loc=dir.lib)
  library(tidyverse,lib.loc=dir.lib)
  library(fixest,lib.loc=dir.lib)
  library(broom,lib.loc=dir.lib)
## otherwise: simple load -----------------
} else {
  library(rio)
  library(dplyr)
  library(tidyverse)
  library(fixest)
  library(broom)
}
## --------------------------------------------------------------------

## file preamble ------------------------------------------------------
## file-specific settings -------
##   name.Y   : column of 4-main.dta used as the outcome Y
##   name.out : stem of the CSV written to dir.est
##   bootrep  : number of bootstrap replicates (iteration 0 is the full sample)
name.Y   = 'cite_ny1'
name.out = 'recid_cov'
bootrep  = 250

## directories, all relative to the project root dir.proj --------
dir.data = paste0(dir.proj,'data/out/')
dir.est  = paste0(dir.proj,'estimates/')
dir.temp = paste0(dir.proj,'_temp/')

## randomization seed (makes the bootstrap weight draws reproducible) --------
set.seed(413)

## functions ----------------
## rudirichlet(n, d): n independent draws from the flat Dirichlet(1, ..., 1)
## distribution in d dimensions. Returns an n x d matrix whose rows each sum
## to 1, obtained by normalizing i.i.d. Exp(1) variates within each row.
rudirichlet <- function(n, d) {
  expo <- rexp(n * d, rate = 1)
  draws <- matrix(expo, nrow = n, ncol = d)
  row.tot <- rowSums(draws)
  draws / row.tot
}
## ---------------------------------------------------------------------


## Setup data -------------------
## Y = outcome column named by name.Y, D = harsh (treatment), W = covbin1
## (covariate bin defining heterogeneity groups), jid = officerid (the
## bootstrap cluster).
data = import(paste0(dir.data,'4-main.dta'))
data = mutate(data,Y=.data[[name.Y]],D=harsh,W=covbin1,jid=officerid) %>%
  select(citationid,Y,D,Z,W,totfe,jid,countynum,troop)

## G = heterogeneity group (integer code of W's factor levels);
## totfe_G = fixed-effect cell interacted with G ------
data = mutate(
  data,
  G=as.numeric(as.factor(W)),
  totfe_G = paste(totfe,G,sep='_'))

## Polynomials in Z (Zk = Z^k, k = 1..8) ---
data = mutate(
  data,
  Z1=Z^1,Z2=Z^2,Z3=Z^3,Z4=Z^4,Z5=Z^5,Z6=Z^6,Z7=Z^7,Z8=Z^8)


## Setup for clustered (Bayesian) bootstrap -------------------
## One row per cluster (jid). Column j of `weights` holds the cluster weights
## for bootstrap replicate j.
## rudirichlet(n, d) returns n Dirichlet draws, one per row. We therefore draw
## bootrep rows over the clusters and transpose. Each replicate then gets an
## exact Dirichlet(1, ..., 1) weighting over clusters, independent across
## replicates. The previous call, rudirichlet(nrow(list), bootrep), normalized
## each cluster across replicates instead. That gave every cluster the same
## scale factor in all replicates, correlating them.
## Weights are rescaled to average 1, matching bootwt = 1 in the full-sample
## iteration. Weighted means, WLS fits and normalized group shares are all
## invariant to this scale.
list    = unique(select(data,jid))
weights = t(rudirichlet(bootrep, nrow(list))) * nrow(list)


## Looping -----------
## Iteration j = 0 is the point estimate on the original sample (every
## cluster weight = 1). Iterations 1..bootrep reweight clusters by column j
## of `weights`. One result table per (iteration, polynomial order) goes into
## a preallocated list, which is stacked into boot.out after the loop.
q.list   = c(2)
boot.res = vector('list', (bootrep + 1) * length(q.list))
n.res    = 0
for(j in 0:bootrep) {

  ## Setup bootstrap data --------
  boot.wt   = if (j == 0) 1 else weights[,j]
  boot.list = list %>% mutate(bootwt = boot.wt)
  boot.data = inner_join(data,boot.list,by='jid')

  ## Get sample means by group (bootwt-weighted): Y = mean outcome,
  ## p = mean of D ------------
  mu = boot.data %>% group_by(G) %>% summarize(
    Y = weighted.mean(Y,bootwt),
    p = weighted.mean(D,bootwt))

  ## Get other pars ------------
  ## No-intercept regression of Y on i(G) and D:i(G) (saturated in G x D).
  ## The G dummies give y0d0, the weighted mean of Y among D == 0 in G.
  ## y0d0 + delta gives y1d1, the mean among D == 1 (assuming D is binary).
  fit  = feols(Y~0+D*i(G)-D,boot.data,weights=~bootwt)
  coef = broom::tidy(fit) %>%
    separate(term,sep='::',into=c('stub1','stub2')) %>%
    separate(stub2,sep=':',into=c('G','D'),convert=TRUE,fill='right')
  mud = inner_join(
    coef %>% filter(is.na(D)) %>% select(G,y0d0=estimate),
    coef %>% filter(!is.na(D)) %>% select(G,delta=estimate),by='G') %>%
    mutate(y1d1=y0d0+delta) %>%
    select(G,y0d0,y1d1)

  ## Get group weights (bootwt-weighted group shares) -----------
  wt = boot.data %>% group_by(G) %>% summarize(N=sum(bootwt)) %>%
    mutate(wt = N/sum(N)) %>% select(c(G,wt))

  ## Loop over polynomial order ------
  for(q in q.list) {

    ## Polynomial fit: group-specific slopes on Z1..Zq with totfe_G FE -----
    z.terms = paste(paste0('Z',seq_len(q)),collapse='+')
    fit = fixest::feols(
      as.formula(paste0('Y~i(G)*(',z.terms,')-i(G)-(',z.terms,')|totfe_G')),
      boot.data,weights=~bootwt)

    ## Means of FE --------
    ## Weighted by bootwt, like every other sample moment in this replicate
    ## (identical to the plain mean when j == 0).
    fe = bind_cols(
      boot.data %>% select(G,bootwt),
      stats::predict(fit,newdata=boot.data,fixef=TRUE)) %>%
      group_by(G) %>%
      summarize(fe=weighted.mean(totfe_G,bootwt,na.rm=TRUE))

    ## Coefficients --------
    ## The per-group sum of the Z1..Zq slopes is the fitted polynomial's
    ## change from Z = 0 to Z = 1.
    coef = broom::tidy(fit) %>%
      mutate(temp = stringr::str_remove(term,'G::')) %>%
      separate(temp,sep=':Z',into=c('G','q')) %>%
      mutate(G = as.numeric(G)) %>%
      group_by(G) %>% summarize(sum=sum(estimate))

    ## Compile stuff --------
    compile = left_join(wt,mu,by='G') %>% left_join(mud,by='G') %>%
      left_join(fe,by='G') %>% left_join(coef,by='G') %>%
      mutate(
        y0=fe,y1=fe+sum,Y0=y0,Y1=y1,ATE=Y1-Y0,ATT=(Y-Y0)/p,ATU=(Y1-Y)/(1-p),
        Y1D1=y1d1,Y1D0=(1/(1-p))*y1-(p/(1-p))*y1d1,
        Y0D0=y0d0,Y0D1=(1/p)*y0-((1-p)/p)*y0d0,
        Y0D1mY0D0 = Y0D1-Y0D0, ATTmATU=ATT-ATU)

    ## Reshape to one row per (group, parameter) --------------
    est.groups = compile %>%
      select(G,Y,p,Y0,Y1,ATE,ATT,ATU,Y1D1,Y1D0,Y0D0,Y0D1,Y0D1mY0D0,ATTmATU) %>%
      pivot_longer(!G,names_to='par',values_to='est')

    ## Aggregate up across groups using the group shares ------
    est = left_join(est.groups,wt,by='G') %>% group_by(par) %>%
      summarize(est=weighted.mean(est,wt)) %>%
      mutate(iter=j,poly=q)

    ## Store output ------------
    n.res = n.res + 1
    boot.res[[n.res]] = est
  }

  ## Progress report -----
  print(paste('Iteration',j,'of',bootrep,'Completed'))

}
boot.out = bind_rows(boot.res)
## End of bootstrap -------------------------



## Bootstrap output -----------
## est.main: point estimates from iteration 0.
## est.boot: bootstrap SD (se) and 5th/95th percentiles (lq/uq) over
## iterations 1..bootrep.
## lb/ub: normal-approximation bounds est +/- 1.96*se.
## NOTE(review): lq/uq span a 90% percentile interval while lb/ub are 95%
## normal bounds -- confirm this mismatch is intended.
est.main = boot.out %>% filter(iter==0) %>% select(poly,par,est)
est.boot = boot.out %>% filter(iter>=1) %>%
  group_by(poly,par) %>%
  summarize(se=sd(est),lq=quantile(est,0.05),uq=quantile(est,0.95),
            .groups='drop')
est.out  = inner_join(est.main,est.boot,by=c('poly','par')) %>%
  mutate(lb=est-1.96*se,ub=est+1.96*se)
export(est.out,paste0(dir.est,name.out,'.csv'))

